{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Saving and Loading Models\n",
    "\n",
    "When it comes to saving and loading models, there are 3 core functions:\n",
    "\n",
    "- ***torch.save***: Saves a serialized object to disk using Python's **pickle** utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved.\n",
    "\n",
    "- ***torch.load***: Uses pickle's uppickling facilities to deserialize pickled object files to memory.\n",
    "\n",
    "- ***torch.nn.Module.load_state_dict***: Loads a model's parameter dictionary using a deserialized state_dict.\n",
    "\n",
    "## 1. What is state_dict?\n",
    "\n",
    "In PyTorch, the learnable parameters (i.e. weights and biases) of an *torch.nn.Module* model are contained in the model's parameters (accessed with *model.parameters()*). A *state_dict* is simply a Python dictionary object that maps each layer to its parameter tensor.\n",
    "\n",
    "Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers have entries in the model's *state_dict*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TheModelClass(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(TheModelClass, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = x.view(-1, 16 * 5 * 5)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = TheModelClass()\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model's state_dict: \n",
      "conv1.weight \t torch.Size([6, 3, 5, 5])\n",
      "conv1.bias \t torch.Size([6])\n",
      "conv2.weight \t torch.Size([16, 6, 5, 5])\n",
      "conv2.bias \t torch.Size([16])\n",
      "fc1.weight \t torch.Size([120, 400])\n",
      "fc1.bias \t torch.Size([120])\n",
      "fc2.weight \t torch.Size([84, 120])\n",
      "fc2.bias \t torch.Size([84])\n",
      "fc3.weight \t torch.Size([10, 84])\n",
      "fc3.bias \t torch.Size([10])\n"
     ]
    }
   ],
   "source": [
    "print(\"Model's state_dict: \")\n",
    "for param_tensor in model.state_dict():\n",
    "    print(param_tensor, '\\t', model.state_dict()[param_tensor].size())  # conv1,conv2,fc1,fc2,fc3的weight和bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimizer's state_dict:\n",
      "state \t {}\n",
      "param_groups \t [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4753872288, 4753873296, 4753874736, 4753873152, 4753874160, 4753873368, 4753874376, 4753874520, 4753874232, 4753874880]}]\n"
     ]
    }
   ],
   "source": [
    "print(\"Optimizer's state_dict:\")\n",
    "for var_name in optimizer.state_dict():\n",
    "    print(var_name, '\\t', optimizer.state_dict()[var_name])  # state和param_groups"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Saving & Loading Model for Inference\n",
    "\n",
    "### Save/Load state_dict (Recommended)\n",
    "\n",
    "***torch.save()*** function gives you the most flexibility for restoring the model later, which is why it's the recommended method for saving models.\n",
    "\n",
    "A common PyTorch convention is to save models using either a ***.pt*** or ***.pth*** file extension.\n",
    "\n",
    "Remember that you must call ***model.eval()*** to set **dropout and batch normalization layers** to evaluation mode before running inference. Failing to do this will yield **inconsistent inference result**.\n",
    "\n",
    "#### Save:\n",
    "\n",
    "```\n",
    "torch.save(model.state_dict(), PATH)\n",
    "```\n",
    "\n",
    "#### Load:\n",
    "\n",
    "```\n",
    "model = TheModelClass(*args, **kwars)\n",
    "model.load_state_dict(torch.load(PATH))\n",
    "model.eval\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save/Load Entire Model\n",
    "\n",
    "Saving a model in this way will save the entire module using Python's *pickle* module.\n",
    "\n",
    "The disadvantage of this approach is that the serialized data is **bound to the specific classes and the exact directory structure** used when the model is saved. The reason for this is bacause pickle doesn't save the model class itself. Rather, it saves a path to the file containing the class.\n",
    "\n",
    "#### Save:\n",
    "\n",
    "```\n",
    "torch.save(model, PATH)\n",
    "```\n",
    "\n",
    "#### Loads:\n",
    "\n",
    "```\n",
    "model = torch.load(PATH)  # Model class must be defined somewhere\n",
    "model.eval()\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Saving & Loading a General Checkpoint for Inference and/or Resuming Training\n",
    "\n",
    "When saving a general checkpoint, to be used for either inference or resuming training, you must save more than just **the model's state_dict**. It's important to alse save **the optimizer's state_dict**, as this contains buffers and parameters that are updated as the model trains. Other items that you may want to save are **epoch you left off on, the latest recorded training loss, external torch.nn.Embedding layer**, etc.\n",
    "\n",
    "A common PyTorch convention is to save these checkpoints using the ***.tar*** file extension.\n",
    "\n",
    "#### Save:\n",
    "\n",
    "```\n",
    "torch.save({\n",
    "    'epoch': epoch,\n",
    "    'model_state_dict': model.state_dict(),\n",
    "    'optimizer_state_dict': optimizer.state_dict(),\n",
    "    'loss': loss,\n",
    "    ...\n",
    "    }, PATH)\n",
    "```\n",
    "\n",
    "#### Load:\n",
    "\n",
    "```\n",
    "model = TheModelClass(*args, **kwargs)\n",
    "optimizer = TheOptimizerClass(*args, **kwargs)\n",
    "\n",
    "checkpoint = torch.load(PATH)\n",
    "model.load_state_dict(checkpoint['model_state_dict'])\n",
    "optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n",
    "epoch = checkpoint['epoch']\n",
    "loss = checkpoint['loss']\n",
    "\n",
    "model.eval()\n",
    "# - Or -\n",
    "model.train()\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Saving Multiple Models in One File\n",
    "\n",
    "#### Save:\n",
    "\n",
    "```\n",
    "torch.save({\n",
    "            'modelA_state_dict': modelA.state_dict(),\n",
    "            'modelB_state_dict': modelB.state_dict(),\n",
    "            'optimizerA_state_dict': optimizerA.state_dict(),\n",
    "            'optimizerB_state_dict': optimizerB.state_dict(),\n",
    "            ...\n",
    "            }, PATH)\n",
    "```\n",
    "\n",
    "#### Load:\n",
    "\n",
    "```\n",
    "modelA = TheModelAClass(*args, **kwargs)\n",
    "modelB = TheModelBClass(*args, **kwargs)\n",
    "optimizerA = TheOptimizerAClass(*args, **kwargs)\n",
    "optimizerB = TheOptimizerBClass(*args, **kwargs)\n",
    "\n",
    "checkpoint = torch.load(PATH)\n",
    "modelA.load_state_dict(checkpoint['modelA_state_dict'])\n",
    "modelB.load_state_dict(checkpoint['modelB_state_dict'])\n",
    "optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])\n",
    "optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])\n",
    "\n",
    "modelA.eval()\n",
    "modelB.eval()\n",
    "# - or -\n",
    "modelA.train()\n",
    "modelB.train()\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Warmstarting Model Using Parameter Model\n",
    "\n",
    "Partially loading a model or loading a partial model are common scenarios when transfer learning or training a new complex model.\n",
    "\n",
    "Whether you are loading from ***a partial state_dict***, which is **missing some keys**, or loading a *state_dict* with **more keys than the model** that you are loading into, you can set the ***strict*** argument to ***False*** in the *load_state_dict()* function to ignore non-matching keys.\n",
    "\n",
    "If you want to load parameters from one layer to another, but some keys don't match, simply **change the name of the parameter keys** in the *state_dict** that you are loading to match the keys in the model that you are loading into.\n",
    "\n",
    "> 刘尧：可以**在不同结构和参数的model间互相load**! 只填充相同key的parameters，有时也可通过修改name来填充到指定的parameters！\n",
    "\n",
    "#### Save:\n",
    "\n",
    "```\n",
    "torch.save(modelA.state_dict(), PATH)\n",
    "```\n",
    "\n",
    "#### Load:\n",
    "\n",
    "```\n",
    "modelB = TheModelBClass(*args, **kwargs)\n",
    "modelB.load_state_dict(torch.load(PATH), strict=False)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Saving & Loading Model Across Devices\n",
    "\n",
 刘尧：总结为一句话就是">
    "> 刘尧: In one sentence: when loading across devices, you need map_location; when running on a GPU, you need model.to(device) and tensor.to(device).\n",
    "\n",
    "#### Save:\n",
    "\n",
    "```\n",
    "torch.save(model.state_dict(), PATH)\n",
    "```\n",
    "\n",
    "### Save on GPU, Load on CPU\n",
    "\n",
    "#### Load:\n",
    "\n",
    "```\n",
    "device = torch.device('cpu')\n",
    "model = TheModelClass(*args, **kwargs)\n",
    "model.load_state_dict(torch.load(PATH, map_location=device))\n",
    "```\n",
    "\n",
    "### Save on GPU, Load on GPU\n",
    "\n",
    "#### Load:\n",
    "\n",
    "```\n",
    "device = torch.device(\"cuda\")\n",
    "model = TheModelClass(*args, **kwargs)\n",
    "model.load_state_dict(torch.load(PATH))\n",
    "model.to(device)\n",
    "# Make sure to call input = input.to(device) on any input tensors that you feed to the model\n",
    "```\n",
    "\n",
    "### Save on CPU, Load on GPU\n",
    "\n",
    "#### Load:\n",
    "\n",
    "```\n",
    "device = torch.device(\"cuda\")\n",
    "model = TheModelClass(*args, **kwargs)\n",
    "model.load_state_dict(torch.load(PATH, map_location=\"cuda:0\"))\n",
    "model.to(device)\n",
    "# Make sure to call input = input.to(device) on any input tensors that you feed to the model\n",
    "```\n",
    "\n",
    "### Saving torch.nn.DataParallel\n",
    "\n",
    "***torch.nn.DataParallel*** is a model wrapper that enables parallel GPU utilization. To save a *DataParallel* model generically, save the ***model.module.state_dict()***. This way, you have the flexibility to load the model any way you want to any device you want.\n",
    "\n",
    "#### Save:\n",
    "\n",
    "```\n",
    "torch.save(model.module.state_dict(), PATH)\n",
    "```\n",
    "\n",
    "#### Load:\n",
    "\n",
    "```\n",
    "# Load to whatever device you want\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
